{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Naive Bayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>y</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>S</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>M</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>M</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>S</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>S</td>\n",
       "      <td>-1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   x1 x2  y\n",
       "0   1  S -1\n",
       "1   1  M -1\n",
       "2   1  M  1\n",
       "3   1  S  1\n",
       "4   1  S -1"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "### 构造数据集\n",
    "### 来自于李航统计学习方法表4.1\n",
    "x1 = [1,1,1,1,1,2,2,2,2,2,3,3,3,3,3]\n",
    "x2 = ['S','M','M','S','S','S','M','M','L','L','L','M','M','L','L']\n",
    "y = [-1,-1,1,1,-1,-1,-1,1,1,1,1,1,1,1,-1]\n",
    "\n",
    "df = pd.DataFrame({'x1':x1, 'x2':x2, 'y':y})\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = df[['x1', 'x2']]\n",
    "y = df[['y']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nb_fit(X, y):\n",
    "    classes = y[y.columns[0]].unique()\n",
    "    class_count = y[y.columns[0]].value_counts()\n",
    "    class_prior = class_count/len(y)\n",
    "    \n",
    "    prior = dict()\n",
    "    for col in X.columns:\n",
    "        for j in classes:\n",
    "            p_x_y = X[(y==j).values][col].value_counts()\n",
    "            for i in p_x_y.index:\n",
    "                prior[(col, i, j)] = p_x_y[i]/class_count[j]\n",
    "    return classes, class_prior, prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{('x1', 1, -1): 0.5, ('x1', 2, -1): 0.3333333333333333, ('x1', 3, -1): 0.16666666666666666, ('x1', 3, 1): 0.4444444444444444, ('x1', 2, 1): 0.3333333333333333, ('x1', 1, 1): 0.2222222222222222, ('x2', 'S', -1): 0.5, ('x2', 'M', -1): 0.3333333333333333, ('x2', 'L', -1): 0.16666666666666666, ('x2', 'L', 1): 0.4444444444444444, ('x2', 'M', 1): 0.4444444444444444, ('x2', 'S', 1): 0.1111111111111111}\n"
     ]
    }
   ],
   "source": [
    "classes, class_prior, prior = nb_fit(X, y)\n",
    "print(classes, class_prior, prior)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_test = {'x1': 2, 'x2': 'S'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "classes, class_prior, prior = nb_fit(X, y)\n",
    "\n",
    "def predict(X_test):\n",
    "    res = []\n",
    "    for c in classes:\n",
    "        p_y = class_prior[c]\n",
    "        p_x_y = 1\n",
    "        for i in X_test.items():\n",
    "            p_x_y *= prior[tuple(list(i)+[c])]\n",
    "        res.append(p_y*p_x_y)\n",
    "    return classes[np.argmax(res)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "测试数据预测类别为： -1\n"
     ]
    }
   ],
   "source": [
    "print('测试数据预测类别为：', predict(X_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of GaussianNB in iris data test: 0.9466666666666667\n"
     ]
    }
   ],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.naive_bayes import GaussianNB\n",
    "from sklearn.metrics import accuracy_score\n",
    "X, y = load_iris(return_X_y=True)\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)\n",
    "gnb = GaussianNB()\n",
    "y_pred = gnb.fit(X_train, y_train).predict(X_test)\n",
    "print(\"Accuracy of GaussianNB in iris data test:\", \n",
    "      accuracy_score(y_test, y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
